/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.amazonaws.services.glue.marketplace.connector.tpcds

import com.teradata.tpcds.Results.constructResults
import com.teradata.tpcds.Session
import com.teradata.tpcds.Table
import com.teradata.tpcds.column.Column
import com.teradata.tpcds.column.ColumnType.Base.{CHAR, DATE, DECIMAL, IDENTIFIER, INTEGER, TIME, VARCHAR}
import org.apache.spark.sql.types.{DataType, DateType, DecimalType, IntegerType, LongType, StringType}

import java.util
import collection.JavaConverters._

/** Utility object that gathers commonly used functions
 * for working with TPC-DS datasets, tables, columns and schemas.
 */

object TPCDSUtils {

  /** Look up a TPC-DS table by name in the TPC-DS base table list.
   *
   * The requested name is upper-cased with `Locale.ROOT`, so the lookup does not depend on
   * the JVM default locale, and is then compared with the string form of each base table.
   *
   * @param tableName The requested table name.
   * @return The TPC-DS table.
   * @throws java.util.NoSuchElementException if no base table matches the requested name.
   */
  def extractTable(tableName: String): Table = {
    val normalizedName = tableName.toUpperCase(java.util.Locale.ROOT)
    Table
      .getBaseTables
      .asScala
      .find(_.toString == normalizedName)
      // Same exception type as calling `.head` on an empty result, but with a message that names the table.
      .getOrElse(throw new NoSuchElementException(
        s"The specified table: $tableName is not found in the TPC-DS base tables."))
  }

  /** Return the iterator of TPC-DS dataset.
   * If the dataset is generated by chunk, returns the partial iterator which is split by chunk.
   *
   * The generator session is derived from the default session with the given scale, table,
   * parallelism and chunk number, and with the `noSexism` option enabled.
   *
   * @param table The TPC-DS table.
   * @param scale The scale factor in TPC-DS.
   * @param parallelism The number of concurrency for data generation in parallel.
   * @param chunkNumber Chunked dataset number. This is less than or equal to parallelism.
   * @return The iterator of TPC-DS dataset.
   */
  def generateChunkIterator(table: Table,
                            scale: Int,
                            parallelism: Int,
                            chunkNumber: Int): Iterator[util.List[util.List[String]]] = {
    val session: Session =
      Session
        .getDefaultSession
        .withScale(scale)
        .withTable(table)
        .withParallelism(parallelism)
        .withChunkNumber(chunkNumber)
        .withNoSexism(true)

    constructResults(table, session)
      .iterator()
      .asScala
  }

  /** Convert the actual values into Scala types.
   *
   *  The type relationships between Spark and Scala are in https://spark.apache.org/docs/2.4.3/sql-reference.html.
   *  A `null` value is returned as `null` for every supported type instead of failing inside
   *  the number/date parsers (presumably generated rows can carry `null` for SQL NULLs; confirm against the caller).
   *
   *  @param value The value related to the column. May be `null`.
   *  @param dataType The Spark data type related to the value.
   *  @return The Scala type which is converted from relevant Spark type, or `null` when `value` is `null`.
   *  @throws java.lang.IllegalArgumentException if `dataType` is not one of the supported types.
   */
  def convertValueType(value: String, dataType: DataType): Any = {
    // Resolve the converter before looking at the value, so that an unsupported
    // type is still rejected when the value is null.
    val convert: String => Any = dataType match {
      case StringType => (s: String) => s
      case LongType => (s: String) => s.toLong
      case IntegerType => (s: String) => s.toInt
      case DecimalType() => (s: String) => new java.math.BigDecimal(s)
      case DateType => (s: String) => java.sql.Date.valueOf(s)
      case _ => throw new IllegalArgumentException(
        s"The specified type: ${dataType.typeName} is not defined.")
    }
    if (value == null) null else convert(value)
  }

  /** Convert TPC-DS data type to Spark Data types.
   *
   *  All TPC-DS type are in https://github.com/Teradata/tpcds/blob/master/src/main/java/com/teradata/tpcds/column/ColumnType.java#L26.
   *  @param column The TPC-DS column.
   *  @return The Spark type which corresponds to TPC-DS type.
   *  @throws java.lang.IllegalArgumentException if the base type of the column is not handled here.
   */
  def convertColumnType(column: Column): DataType = {
    val columnType = column.getType
    columnType.getBase match {
      case IDENTIFIER => LongType
      case INTEGER => IntegerType
      case DATE => DateType
      case DECIMAL =>
        DecimalType(columnType.getPrecision.get(), columnType.getScale.get())
      // CHAR will be mapped to java.lang.String.
      // `TIME` type is generated only for `dbgen_version` table. The `TIME` column has the format like "HH:mm:ss",
      // however the format is not compatible with `java.sql.Timestamp`. Then the `TIME` column will be converted to String.
      case TIME | CHAR | VARCHAR => StringType
      case unsupported => throw new IllegalArgumentException(
        s"Unsupported TPC-DS type ${column.getName}:$unsupported")
    }
  }
}
